import torch
from .base import BaseModel
from ..smp import *
from ..dataset import DATASET_TYPE
from torchvision import transforms
from transformers import AutoTokenizer, AutoModel

import random
import numpy as np
import math


def get_frame_indices(
    num_frames, vlen, sample="rand", fix_start=None, input_fps=1, max_num_frames=-1
):
    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == "rand":
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except IndexError:  # an interval collapsed to an empty range
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == "middle":
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[: len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices

    elif "fps" in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = (
            1 / output_fps
        )  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    elif "interval" in sample:
        if num_frames == 1:
            frame_indices = [random.randint(0, vlen - 1)]
        else:
            # fixed stride of 8 frames at a reference 30 fps, rescaled by the clip's actual fps
            interval = 8
            clip_length = num_frames * interval * input_fps / 30
            max_idx = max(vlen - clip_length, 0)
            start_idx = random.uniform(0, max_idx)
            end_idx = start_idx + clip_length - 1

            frame_indices = torch.linspace(start_idx, end_idx, num_frames)
            frame_indices = torch.clamp(frame_indices, 0, vlen - 1).long().tolist()
    else:
        raise ValueError
    return frame_indices
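
# Worked examples (values follow directly from the sampling logic above):
#   get_frame_indices(4, vlen=100, sample="middle")
#     -> intervals [0, 25, 50, 75, 100], midpoints [12, 37, 62, 87]
#   get_frame_indices(8, vlen=300, sample="fps0.5", input_fps=30)
#     -> one frame every 2 s of the 10 s clip: [30, 90, 150, 210, 270]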


def get_frame_indices_start_end(num_frames, vlen, fps, start_time, end_time):
    start_idx = (
        max(int(fps * start_time), 0)
        if start_time is not None and not math.isnan(start_time)
        else 0
    )
    end_idx = (
        min(int(fps * end_time), vlen)
        if end_time is not None and not math.isnan(end_time)
        else vlen
    )
    clip_len = end_idx - start_idx

    acc_samples = min(num_frames, clip_len)
    # split the video into `acc_samples` intervals, and sample from each interval.
    intervals = np.linspace(start=start_idx, stop=end_idx, num=acc_samples + 1).astype(
        int
    )
    ranges = []
    for idx, interv in enumerate(intervals[:-1]):
        ranges.append((interv, intervals[idx + 1] - 1))

    try:
        frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
    except IndexError:  # an interval collapsed to an empty range
        frame_indices = np.random.permutation(list(range(start_idx, end_idx)))[
            :acc_samples
        ]
        frame_indices.sort()
        frame_indices = list(frame_indices)

    if len(frame_indices) < num_frames:  # padded with last frame
        padded_frame_indices = [frame_indices[-1]] * num_frames
        padded_frame_indices[: len(frame_indices)] = frame_indices
        frame_indices = padded_frame_indices

    return frame_indices
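
# Worked example (hypothetical values): for a 30 fps video with vlen=300,
# get_frame_indices_start_end(4, vlen=300, fps=30, start_time=2, end_time=4)
# clips to the frame window [60, 120) and draws one random frame from each of
# the sub-intervals [60, 74], [75, 89], [90, 104], [105, 119].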


def read_frames_decord(
    video_path,
    width=None,
    height=None,
    num_frames=8,
    sample="rand",
    fix_start=None,
    max_num_frames=-1,
    start_time=None,
    end_time=None,
):
    import decord

    decord.bridge.set_bridge("torch")
    if video_path.lower().endswith(".webm"):
        # workaround for webm: a large/auto num_threads can cause decord errors
        num_threads = 2
    else:
        num_threads = 0

    if width is not None and height is not None:
        video_reader = decord.VideoReader(
            video_path, width=width, height=height, num_threads=num_threads
        )
    else:
        video_reader = decord.VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    if start_time is not None and end_time is not None:
        frame_indices = get_frame_indices_start_end(
            num_frames, vlen, fps, start_time, end_time
        )
    else:
        frame_indices = get_frame_indices(
            num_frames,
            vlen,
            sample=sample,
            fix_start=fix_start,
            input_fps=fps,
            max_num_frames=max_num_frames,
        )
    frames = video_reader.get_batch(frame_indices)
    if isinstance(frames, torch.Tensor):
        frames = frames.numpy()  # (T, H, W, C), uint8
    else:
        frames = frames.asnumpy()  # decord NDArray -> (T, H, W, C) np.ndarray
    timestamp = {
        "num_frames": len(frame_indices),
        "timestamp": ", ".join([str(round(f / fps, 1)) for f in frame_indices]),
    }
    return frames, timestamp
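
# Minimal usage sketch (assumes decord is installed and `demo.mp4` exists):
#   frames, ts = read_frames_decord("demo.mp4", num_frames=8, sample="middle")
#   frames is a (8, H, W, 3) uint8 array; ts["timestamp"] is a comma-separated
#   string of the sampled frames' timestamps in seconds.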


class mPLUG_Owl3(BaseModel):
    # No separate model module is required, but the dependencies must be met.
    # https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl3/requirements.txt
    INSTALL_REQ = True
    INTERLEAVE = True
    INSTALL_REQ_TXT = (
        "https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl3/requirements.txt"
    )

    def __init__(self, model_path=None, **kwargs):
        assert model_path is not None
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        self.model = AutoModel.from_pretrained(
            model_path,
            attn_implementation="sdpa",
            torch_dtype=torch.half,
            trust_remote_code=True,
        )
        self.model.eval().cuda()
        self.processor = self.model.init_processor(self.tokenizer)
        self.logger = get_logger("mPLUG_Owl3")
        if self.INSTALL_REQ:
            self.logger.info(
                "Please make sure the requirements are installed first.\n"
                f"See: {self.INSTALL_REQ_TXT}"
            )

    def use_custom_prompt(self, dataset):
        assert dataset is not None
        if listinstr(["MMMU"], dataset):
            return False
        if listinstr(["MVBench", "MMVet"], dataset):
            return True
        return False

    def save_video_into_images(self, line, num_frames=16, dataset_class=None):
        video_url = {
            "video": osp.join(line["prefix"], line["video"]),
            "num_frames": num_frames,
            "bound": line.get("bound", None),
        }
        if osp.isdir(video_url["video"]):
            frame_paths = []
            max_frame = len(os.listdir(video_url["video"]))
            fps = 3  # frame folders are assumed to be pre-extracted at 3 fps (MVBench convention)
            if video_url["bound"]:
                start, end = line["start"], line["end"]
            else:
                start, end = -100000, 100000  # sentinel bounds: effectively unbounded
            start_idx = max(1, round(start * fps))
            end_idx = min(round(end * fps), max_frame)
            seg_size = float(end_idx - start_idx) / num_frames
            frame_indices = np.array(
                [
                    int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
                    for idx in range(num_frames)
                ]
            )

            for frame_index in frame_indices:
                img = os.path.join(video_url["video"], f"{frame_index:05d}.jpg")
                frame_paths.append(img)

            return frame_paths

        if isinstance(video_url, dict):
            if video_url["bound"]:
                start_time = line["start"]
                end_time = line["end"]
            else:
                start_time = None
                end_time = None
            num_frames = video_url.get("num_frames", num_frames)
            video_url = video_url["video"]
        else:
            start_time = None
            end_time = None
            video_url = str(video_url)

        if not osp.exists(video_url):  # for MVBench_MP4
            video_url = osp.join(dataset_class.data_root, video_url)
        video, timestamp = read_frames_decord(
            video_url,
            num_frames=num_frames,
            sample="middle",
            start_time=start_time,
            end_time=end_time,
        )

        to_pil = transforms.ToPILImage()
        frames = [to_pil(video[ti]) for ti in range(video.shape[0])]
        lmu_root = LMUDataRoot()
        frame_root = osp.join(
            lmu_root, "images", dataset_class.dataset_name, "mplug_owl3"
        )
        frame_root = osp.join(frame_root, video_url.split("/")[-1].split(".")[0])
        os.makedirs(frame_root, exist_ok=True)
        frame_tmpl = "frame-{}-of-{}.jpg"
        frame_paths = [
            osp.join(frame_root, frame_tmpl.format(i, num_frames))
            for i in range(1, num_frames + 1)
        ]
        for im, pth in zip(frames, frame_paths):
            if not osp.exists(pth):
                im.save(pth)

        return frame_paths
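
    # Example layout (hypothetical clip `clip001.mp4`, num_frames=16):
    #   <LMUDataRoot>/images/<dataset_name>/mplug_owl3/clip001/frame-1-of-16.jpg
    #   ...
    #   <LMUDataRoot>/images/<dataset_name>/mplug_owl3/clip001/frame-16-of-16.jpg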

    # Currently the same as mPLUG_Owl2
    def build_prompt(self, line, dataset=None, num_frames=16, video_llm=False):
        dataset_class = None
        if not isinstance(dataset, str):
            dataset_class = dataset
            dataset = dataset_class.dataset_name
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        if dataset_class is not None and dataset_class.MODALITY == "VIDEO":
            if listinstr(["MVBench"], dataset):
                tgt_path = self.save_video_into_images(line, num_frames, dataset_class)
            else:
                tgt_path = dataset_class.save_video_into_images(line, num_frames)
            if type(line["candidates"]) is not list:
                # candidates are serialized as a string in the TSV; parse back to a list
                line["candidates"] = eval(line["candidates"])
            for idx, c in enumerate(line["candidates"]):
                line[chr(ord("A") + idx)] = c
        else:
            tgt_path = self.dump_image(line, dataset)
        question = line["question"]
        if dataset == "MMVet":
            prompt = question + "\nAnswer the question directly. "
        elif listinstr(["MCQ", "Video-MCQ"], DATASET_TYPE(dataset)):
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ""
            for key, item in options.items():
                options_prompt += f"{key}. {item}\n"

            hint = (
                line["hint"] if ("hint" in line and not pd.isna(line["hint"])) else None
            )
            prompt = f"Hint: {hint}\n" if hint is not None else ""
            prompt += f"{question}\n"
            prompt += (
                f"{options_prompt}\nAnswer with the option's letter from the given choices directly. "
                if len(options)
                else "Answer the question directly. "
            )
        else:
            raise NotImplementedError

        message = [dict(type="text", value=prompt)]
        message.extend([dict(type="image", value=s) for s in tgt_path])
        return message
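
    # Example of the returned message for an MCQ line (hypothetical content):
    #   [{"type": "text", "value": "Q?\nA. cat\nB. dog\n\nAnswer with the
    #     option's letter from the given choices directly. "},
    #    {"type": "image", "value": ".../frame-1-of-16.jpg"}, ...]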

    def preproc_image(self, fname, dataset=None):
        from PIL import Image

        image = Image.open(fname).convert("RGB")
        # downscale so the longer side is at most max_size, preserving aspect ratio
        max_size = 448 * 16
        if max(image.size) > max_size and not listinstr(["MVBench"], dataset):
            w, h = image.size
            if w > h:
                new_w = max_size
                new_h = int(h * max_size / w)
            else:
                new_h = max_size
                new_w = int(w * max_size / h)
            image = image.resize((new_w, new_h), resample=Image.BICUBIC)
        return image
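
    # Example: a 10000x5000 image exceeds max_size = 448 * 16 = 7168, so it is
    # downscaled to 7168x3584 (aspect ratio preserved); MVBench inputs skip the
    # resize entirely.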

    def generate_inner(self, message, dataset=None):
        # Build the interleaved prompt: insert an <|image|> placeholder at each
        # image position, in the order images appear in the message.
        images = []
        prompt_full = ""

        for msg in message:
            if msg["type"] == "image":
                images.append(msg["value"])
                prompt_full += "<|image|>"
            elif msg["type"] == "text":
                prompt_full += msg["value"]

        needed_messages = [
            {"role": "user", "content": prompt_full},
            {"role": "assistant", "content": ""},
        ]

        images = [self.preproc_image(fname, dataset) for fname in images]

        inputs = self.processor(
            needed_messages, images=images, videos=None, cut_enable=False
        )

        inputs.to("cuda")
        if listinstr(["MVBench"], dataset):
            inputs.update(
                {
                    "tokenizer": self.tokenizer,
                    "max_new_tokens": 100,
                    "decode_text": True,
                    "do_sample": True,
                    "top_k": 1,  # top_k=1 makes sampling effectively greedy
                }
            )
        else:
            inputs.update(
                {
                    "tokenizer": self.tokenizer,
                    "max_new_tokens": 1024,
                    "decode_text": True,
                }
            )

        g = self.model.generate(**inputs)
        return g[0]
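

if __name__ == "__main__":
    # Smoke test for the pure frame-sampling helpers above (no video, no GPU).
    # Run as a module, e.g. `python -m <package>.mplug_owl3` (path hypothetical),
    # so the relative imports at the top of this file resolve.
    assert get_frame_indices(4, vlen=100, sample="middle") == [12, 37, 62, 87]
    assert get_frame_indices(
        8, vlen=300, sample="fps0.5", input_fps=30
    ) == [30, 90, 150, 210, 270]
    idx = get_frame_indices_start_end(4, vlen=300, fps=30, start_time=2, end_time=4)
    assert len(idx) == 4 and all(60 <= i < 120 for i in idx)
    print("frame-sampling helpers OK")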
